{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {},
  "cells": [
    {
      "metadata": {},
      "source": [
        "<td>\n",
        "   <a target=\"_blank\" href=\"https://labelbox.com\" ><img src=\"https://labelbox.com/blog/content/images/2021/02/logo-v4.svg\" width=\"256\"/></a>\n",
        "</td>"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "<td>\n",
        "<a href=\"https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/integrations/sam/meta_sam_labelbox.ipynb\" target=\"_blank\"><img\n",
        "src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"></a>\n",
        "</td>\n",
        "\n",
        "<td>\n",
        "<a href=\"https://github.com/Labelbox/labelbox-python/blob/master/examples/integrations/sam/meta_sam_labelbox.ipynb\" target=\"_blank\"><img\n",
        "src=\"https://img.shields.io/badge/GitHub-100000?logo=github&logoColor=white\" alt=\"GitHub\"></a>\n",
        "</td>"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Setup\n",
        "This notebook shows how to use Meta's Segment Anything model to create masks that can then be uploaded to a Labelbox project."
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Use %pip (rather than !pip) so the packages are installed into the environment of the running kernel\n",
        "%pip install -q \"labelbox[data]\"\n",
        "%pip install -q ultralytics==8.0.20\n",
        "%pip install -q 'git+https://github.com/facebookresearch/segment-anything.git'"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Detect whether the notebook is running in Google Colab\n",
        "try:\n",
        "  import google.colab  # imported only to probe that the package is available\n",
        "  IN_COLAB = True\n",
        "except ImportError:\n",
        "  # Only a missing package means \"not Colab\"; any other error should surface\n",
        "  IN_COLAB = False"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Import the IPython display helpers once, as functions, so that `display`\n",
        "# is never bound to the IPython.display module and then rebound to the function\n",
        "from IPython.display import Image, clear_output, display\n",
        "clear_output()\n",
        "\n",
        "import ultralytics\n",
        "ultralytics.checks()\n",
        "\n",
        "import cv2\n",
        "import numpy as np\n",
        "from ultralytics import YOLO\n",
        "import torch\n",
        "import matplotlib.pyplot as plt\n",
        "from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor\n",
        "import os\n",
        "import urllib.request\n",
        "import uuid\n",
        "\n",
        "import labelbox as lb\n",
        "import labelbox.types as lb_types\n",
        "\n",
        "# Working directory at import time; model weights are stored relative to it\n",
        "HOME = os.getcwd()\n",
        "\n",
        "# cv2.imshow is replaced by this Colab-specific helper when running in Colab\n",
        "if IN_COLAB:\n",
        "    from google.colab.patches import cv2_imshow"
      ],
      "cell_type": "code",
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Ultralytics YOLOv8.0.20 \ud83d\ude80 Python-3.10.13 torch-2.2.0 CPU\n",
            "Setup complete \u2705 (10 CPUs, 16.0 GB RAM, 251.9/460.4 GB disk)\n"
          ]
        }
      ],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# API Key and Client\n",
        "Provide a valid API key below to connect to the Labelbox client."
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Add your API key, or leave it empty to let the SDK read it from the\n",
        "# LABELBOX_API_KEY environment variable (keeps the key out of the saved notebook)\n",
        "API_KEY=\"\"\n",
        "# To get your API key go to: Workspace settings -> API -> Create API Key\n",
        "# An empty string is passed on as None so that the SDK's environment lookup applies\n",
        "client = lb.Client(api_key=API_KEY or None)"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Predicting bounding boxes around common objects using YOLOv8\n",
        "\n",
        "First, we start with loading the YOLOv8 model, getting a sample image, and running the model on it to generate bounding boxes around some common objects."
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "### Utilize YOLOv8 to Create Bounding Boxes\n",
        "\n",
        "We use YOLOv8 in this demo to obtain bounding boxes around the objects in our image, which we later feed into SAM to generate masks."
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "Below we run inference on an image using the YOLOv8 model."
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# You can also use the Labelbox Client API to get specific images or an entire\n",
        "# dataset from your Catalog. Refer to these docs:\n",
        "# https://labelbox-python.readthedocs.io/en/latest/#labelbox.client.Client.get_data_row\n",
        "\n",
        "IMAGE_PATH = \"https://storage.googleapis.com/labelbox-datasets/image_sample_data/chairs.jpeg\""
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Load the YOLOv8 weights and run detection on the sample image,\n",
        "# keeping predictions with a confidence of at least 0.25\n",
        "model = YOLO(f'{HOME}/yolov8n.pt')\n",
        "results = model.predict(source=IMAGE_PATH, conf=0.25)\n",
        "\n",
        "# Uncomment any of the lines below to inspect the predictions\n",
        "\n",
        "# print(results[0].boxes.xyxy) # print bounding box coordinates\n",
        "\n",
        "# print(results[0].boxes.conf) # print confidence scores\n",
        "\n",
        "# for c in results[0].boxes.cls:\n",
        "#   print(model.names[int(c)]) # print predicted classes"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "Below we visualize the bounding boxes on the image using CV2."
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Make sure the sample image exists locally instead of relying on an earlier\n",
        "# cell having downloaded it as a side effect\n",
        "LOCAL_IMAGE_PATH = \"./chairs.jpeg\"\n",
        "if not os.path.isfile(LOCAL_IMAGE_PATH):\n",
        "  urllib.request.urlretrieve(IMAGE_PATH, LOCAL_IMAGE_PATH)\n",
        "\n",
        "image_bgr = cv2.imread(LOCAL_IMAGE_PATH)\n",
        "# cv2.imread returns None (it does not raise) when the file cannot be read\n",
        "if image_bgr is None:\n",
        "  raise FileNotFoundError(f\"Could not read image at {LOCAL_IMAGE_PATH}\")\n",
        "\n",
        "# Draw on a copy so that image_bgr stays unmodified for the segmentation step\n",
        "image_with_boxes = image_bgr.copy()\n",
        "for box in results[0].boxes.xyxy:\n",
        "  x_min, y_min, x_max, y_max = (int(v) for v in box)\n",
        "  cv2.rectangle(image_with_boxes, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)\n",
        "\n",
        "if IN_COLAB:\n",
        "  cv2_imshow(image_with_boxes)\n",
        "else:\n",
        "  cv2.imshow(\"demo\", image_with_boxes)\n",
        "  cv2.waitKey()\n",
        "  # Close the window after the key press so it does not linger\n",
        "  cv2.destroyAllWindows()"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "### Predicting segmentation masks using Meta's Segment Anything model\n",
        "\n",
        "Now we load Meta's Segment Anything model and feed the bounding boxes to it, so it can generate segmentation masks within them."
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Download SAM model weights (skipped when the checkpoint file already exists)\n",
        "\n",
        "CHECKPOINT_URL = \"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth\"\n",
        "CHECKPOINT_PATH = os.path.join(HOME, \"sam_vit_h_4b8939.pth\")\n",
        "\n",
        "if not os.path.isfile(CHECKPOINT_PATH):\n",
        "    # Write to the same path that is checked above and loaded later\n",
        "    urllib.request.urlretrieve(CHECKPOINT_URL, CHECKPOINT_PATH)\n"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
        "MODEL_TYPE = \"vit_h\""
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Build the SAM variant selected by MODEL_TYPE from the downloaded checkpoint\n",
        "sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)\n",
        "# Move the model to the chosen device\n",
        "sam = sam.to(device=DEVICE)\n",
        "\n",
        "# Predictor wrapper used for prompted (box-based) mask prediction\n",
        "mask_predictor = SamPredictor(sam)"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Put the YOLO boxes on the same device as the SAM model before transforming them\n",
        "# to the coordinate frame the predictor expects\n",
        "boxes_xyxy = results[0].boxes.xyxy.to(mask_predictor.device)\n",
        "transformed_boxes = mask_predictor.transform.apply_boxes_torch(boxes_xyxy, image_bgr.shape[:2])\n",
        "\n",
        "# The image was loaded with cv2.imread (BGR); set_image assumes RGB unless told otherwise\n",
        "mask_predictor.set_image(image_bgr, image_format=\"BGR\")\n",
        "\n",
        "# One mask per box (multimask_output=False); no point prompts are used\n",
        "masks, scores, logits = mask_predictor.predict_torch(\n",
        "    point_coords=None,\n",
        "    point_labels=None,\n",
        "    boxes=transformed_boxes,\n",
        "    multimask_output=False,\n",
        ")\n",
        "masks = np.array(masks.cpu())\n",
        "\n",
        "# print(masks)\n",
        "# print(scores)"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "Here we visualize the segmentation masks drawn on the image."
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)\n",
        "\n",
        "# Union of all predicted masks. masks[:, 0] selects the first (only) mask of each\n",
        "# box, and any(axis=0) also works when just a single mask was predicted\n",
        "final_mask = masks[:, 0].any(axis=0)\n",
        "\n",
        "plt.figure(figsize=(10, 10))\n",
        "plt.imshow(image_rgb)\n",
        "plt.axis('off')\n",
        "# Overlay the combined mask semi-transparently on top of the image\n",
        "plt.imshow(final_mask, cmap='gray', alpha=0.7)\n",
        "\n",
        "plt.show()"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "### Uploading predicted segmentation masks with class names to Labelbox using Python SDK"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Build one Labelbox mask ObjectAnnotation per predicted mask\n",
        "\n",
        "# Value in the mask array that marks the annotated pixels\n",
        "color = (1, 1, 1)\n",
        "\n",
        "# Class name of every detection, in the order of the predicted boxes\n",
        "class_names = [model.names[int(c)] for c in results[0].boxes.cls]\n",
        "\n",
        "annotations = []\n",
        "for idx in range(len(masks)):\n",
        "  binary_mask = np.asarray(masks[idx][0], dtype=\"uint8\")\n",
        "  mask_data = lb_types.MaskData.from_2D_arr(binary_mask)\n",
        "  annotations.append(\n",
        "    lb_types.ObjectAnnotation(\n",
        "      name=class_names[idx],  # class predicted by the object detector\n",
        "      value=lb_types.Mask(mask=mask_data, color=color),\n",
        "    )\n",
        "  )"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Create a new dataset\n",
        "\n",
        "# Global keys must be unique, so a random suffix keeps this cell re-runnable\n",
        "# read more here: https://docs.labelbox.com/reference/data-row-global-keys\n",
        "global_key = \"my_unique_global_key-\" + str(uuid.uuid4())\n",
        "\n",
        "test_img_url = {\n",
        "    \"row_data\": IMAGE_PATH,\n",
        "    \"global_key\": global_key\n",
        "}\n",
        "\n",
        "dataset = client.create_dataset(name=\"auto-mask-classification-dataset\")\n",
        "task = dataset.create_data_rows([test_img_url])\n",
        "# Block until the data row creation task has finished\n",
        "task.wait_till_done()\n",
        "\n",
        "print(f\"Errors: {task.errors}\")\n",
        "print(f\"Failed data rows: {task.failed_data_rows}\")"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Create a new ontology if you don't have one\n",
        "\n",
        "# One raster segmentation tool per distinct class found by the object detector\n",
        "tools = [\n",
        "    lb.Tool(tool=lb.Tool.Type.RASTER_SEGMENTATION, name=name)\n",
        "    for name in set(class_names)\n",
        "]\n",
        "\n",
        "ontology_builder = lb.OntologyBuilder(tools=tools, classifications=[])\n",
        "\n",
        "ontology = client.create_ontology(\n",
        "    \"auto-mask-classification-ontology\",\n",
        "    ontology_builder.asdict(),\n",
        "    media_type=lb.MediaType.Image,\n",
        ")\n",
        "\n",
        "# Alternatively, reuse an existing ontology by name or ID (uncomment one of the below)\n",
        "\n",
        "# ontology = client.get_ontologies(\"Demo Chair\").get_one()\n",
        "\n",
        "# ontology = client.get_ontology(\"clhee8kzt049v094h7stq7v25\")"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Create a new project if you don't have one\n",
        "\n",
        "# Project defaults to batch mode with benchmark quality settings if this argument is not provided\n",
        "# Queue mode will be deprecated once dataset mode is deprecated\n",
        "project = client.create_project(name=\"auto-mask-classification-project\",\n",
        "                                media_type=lb.MediaType.Image\n",
        "                                )\n",
        "\n",
        "# Or get an existing project by ID (uncomment the below)\n",
        "\n",
        "# project = client.get_project(\"fill_in_project_id\")\n",
        "\n",
        "# If the project already has an ontology set up, comment out this line\n",
        "project.setup_editor(ontology)"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Create a new batch of data for the project you specified above\n",
        "\n",
        "data_row_ids = client.get_data_row_ids_for_global_keys([global_key])['results']\n",
        "\n",
        "batch = project.create_batch(\n",
        "    \"auto-mask-classification-batch\",  # each batch in a project must have a unique name\n",
        "    data_rows=data_row_ids,\n",
        "    \n",
        "    # you can also specify global_keys instead of data_rows\n",
        "    #global_keys=[global_key],  # paginated collection of data row objects, list of data row ids or global keys\n",
        "    \n",
        "    priority=1  # priority between 1(highest) - 5(lowest)\n",
        ")\n",
        "\n",
        "print(f\"Batch: {batch}\")"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# A single label that attaches all mask annotations to the data row with this global key\n",
        "labels = [\n",
        "    lb_types.Label(\n",
        "        data=lb_types.ImageData(global_key=global_key),\n",
        "        annotations=annotations,\n",
        "    )\n",
        "]"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Import the predictions into the project as pre-labels for the selected data rows\n",
        "\n",
        "upload_job = lb.MALPredictionImport.create_from_objects(\n",
        "    client=client,\n",
        "    project_id=project.uid,\n",
        "    # random suffix gives every import job a distinct name\n",
        "    name=\"mal_job\" + str(uuid.uuid4()),\n",
        "    predictions=labels,\n",
        ")\n",
        "# Block until the import has finished\n",
        "upload_job.wait_until_done()\n",
        "\n",
        "print(f\"Errors: {upload_job.errors}\")\n",
        "print(f\"Status of uploads: {upload_job.statuses}\")"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "### Cleanup"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "#dataset.delete()\n",
        "#project.delete()"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    }
  ]
}